# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os

import cv2
import numpy as np
import paddle
import yaml
from paddlenlp.utils.downloader import get_path_from_url_with_filelock

from .det_keypoint_unite_utils import argsparser
from .infer import PredictConfig  # noqa F401
from .infer import bench_log, get_test_images, print_arguments
from .keypoint_infer import KeyPointDetector
from .keypoint_postprocess import translate_to_ori_images
from .preprocess import decode_image
from .utils import get_current_memory_mb
from .visualize import visualize_pose

KEYPOINT_SUPPORT_MODELS = {
    "HigherHRNet": "keypoint_bottomup",
    "HRNet": "keypoint_topdown",
}


def predict_with_given_det(image, det_res, keypoint_detector, keypoint_batch_size, run_benchmark):
    keypoint_res = {}

    rec_images, records, det_rects = keypoint_detector.get_person_from_rect(image, det_res)

    if len(det_rects) == 0:
        keypoint_res["keypoint"] = [[], []]
        return keypoint_res

    keypoint_vector = []
    score_vector = []

    rect_vector = det_rects
    keypoint_results = keypoint_detector.predict_image(rec_images, run_benchmark, repeats=10, visual=False)
    keypoint_vector, score_vector = translate_to_ori_images(keypoint_results, np.array(records))
    keypoint_res["keypoint"] = (
        [keypoint_vector.tolist(), score_vector.tolist()] if len(keypoint_vector) > 0 else [[], []]
    )
    keypoint_res["bbox"] = rect_vector
    return keypoint_res


def topdown_unite_predict(
    detector,
    topdown_keypoint_detector,
    image_list,
    keypoint_batch_size=1,
    save_res=False,
):
    det_timer = detector.get_timer()
    store_res = []
    for i, img_file in enumerate(image_list):
        # Decode image in advance in det + pose prediction
        det_timer.preprocess_time_s.start()
        image, _ = decode_image(img_file, {})
        det_timer.preprocess_time_s.end()

        if FLAGS.run_benchmark:
            results = detector.predict_image([image], run_benchmark=True, repeats=10)

            cm, gm, gu = get_current_memory_mb()
            detector.cpu_mem += cm
            detector.gpu_mem += gm
            detector.gpu_util += gu
        else:
            results = detector.predict_image([image], visual=False)
        results = detector.filter_box(results, FLAGS.det_threshold)
        if results["boxes_num"] > 0:
            keypoint_res = predict_with_given_det(
                image,
                results,
                topdown_keypoint_detector,
                keypoint_batch_size,
                FLAGS.run_benchmark,
            )

            if save_res:
                save_name = img_file if isinstance(img_file, str) else i
                store_res.append(
                    [
                        save_name,
                        keypoint_res["bbox"],
                        [keypoint_res["keypoint"][0], keypoint_res["keypoint"][1]],
                    ]
                )
        else:
            results["keypoint"] = [[], []]
            keypoint_res = results
        if FLAGS.run_benchmark:
            cm, gm, gu = get_current_memory_mb()
            topdown_keypoint_detector.cpu_mem += cm
            topdown_keypoint_detector.gpu_mem += gm
            topdown_keypoint_detector.gpu_util += gu
        else:
            if not os.path.exists(FLAGS.output_dir):
                os.makedirs(FLAGS.output_dir)
            visualize_pose(
                img_file,
                keypoint_res,
                visual_thresh=FLAGS.keypoint_threshold,
                save_dir=FLAGS.output_dir,
            )
    if save_res:
        """
        1) store_res: a list of image_data
        2) image_data: [imageid, rects, [keypoints, scores]]
        3) rects: list of rect [xmin, ymin, xmax, ymax]
        4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
        5) scores: mean of all joint conf
        """
        with open("det_keypoint_unite_image_results.json", "w") as wf:
            json.dump(store_res, wf, indent=4)


def topdown_unite_predict_singleimage(
    detector,
    topdown_keypoint_detector,
    image,
    keypoint_batch_size=8,
    det_threshold=0.25,
):

    results = detector.predict_image([image], visual=False)
    results = detector.filter_box(results, det_threshold)
    if results["boxes_num"] > 0:
        keypoint_res = predict_with_given_det(image, results, topdown_keypoint_detector, keypoint_batch_size, False)

    else:
        results["keypoint"] = [[], []]
        keypoint_res = results
    return keypoint_res


def topdown_unite_predict_video(
    detector,
    topdown_keypoint_detector,
    camera_id,
    keypoint_batch_size=1,
    save_res=False,
):
    video_name = "output.mp4"
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]
    # Get Video info : resolution, fps, frame count
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(capture.get(cv2.CAP_PROP_FPS))
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print("fps: %d, frame_count: %d" % (fps, frame_count))

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    out_path = os.path.join(FLAGS.output_dir, video_name)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    index = 0
    store_res = []
    keypoint_smoothing = KeypointSmoothing(width, height, filter_type=FLAGS.filter_type, beta=0.05)

    while 1:
        ret, frame = capture.read()
        if not ret:
            break
        index += 1
        print("detect frame: %d" % (index))

        frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        results = detector.predict_image([frame2], visual=False)
        results = detector.filter_box(results, FLAGS.det_threshold)
        if results["boxes_num"] == 0:
            writer.write(frame)
            continue

        keypoint_res = predict_with_given_det(
            frame2,
            results,
            topdown_keypoint_detector,
            keypoint_batch_size,
            FLAGS.run_benchmark,
        )

        if FLAGS.smooth and len(keypoint_res["keypoint"][0]) == 1:
            current_keypoints = np.array(keypoint_res["keypoint"][0][0])
            smooth_keypoints = keypoint_smoothing.smooth_process(current_keypoints)

            keypoint_res["keypoint"][0][0] = smooth_keypoints.tolist()

        im = visualize_pose(frame, keypoint_res, visual_thresh=FLAGS.keypoint_threshold, returnimg=True)

        if save_res:
            store_res.append(
                [
                    index,
                    keypoint_res["bbox"],
                    [keypoint_res["keypoint"][0], keypoint_res["keypoint"][1]],
                ]
            )

        writer.write(im)
        if camera_id != -1:
            cv2.imshow("Mask Detection", im)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    writer.release()
    print("output_video saved to: {}".format(out_path))
    if save_res:
        """
        1) store_res: a list of frame_data
        2) frame_data: [frameid, rects, [keypoints, scores]]
        3) rects: list of rect [xmin, ymin, xmax, ymax]
        4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
        5) scores: mean of all joint conf
        """
        with open("det_keypoint_unite_video_results.json", "w") as wf:
            json.dump(store_res, wf, indent=4)


class KeypointSmoothing(object):
    # The following code are modified from:
    # https://github.com/jaantollander/OneEuroFilter

    def __init__(
        self,
        width,
        height,
        filter_type,
        alpha=0.5,
        fc_d=0.1,
        fc_min=0.1,
        beta=0.1,
        thres_mult=0.3,
    ):
        super(KeypointSmoothing, self).__init__()
        self.image_width = width
        self.image_height = height
        self.threshold = (
            np.array(
                [
                    0.005,
                    0.005,
                    0.005,
                    0.005,
                    0.005,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                    0.01,
                ]
            )
            * thres_mult
        )
        self.filter_type = filter_type
        self.alpha = alpha
        self.dx_prev_hat = None
        self.x_prev_hat = None
        self.fc_d = fc_d
        self.fc_min = fc_min
        self.beta = beta

        if self.filter_type == "OneEuro":
            self.smooth_func = self.one_euro_filter
        elif self.filter_type == "EMA":
            self.smooth_func = self.ema_filter
        else:
            raise ValueError("filter type must be one_euro or ema")

    def smooth_process(self, current_keypoints):
        if self.x_prev_hat is None:
            self.x_prev_hat = current_keypoints[:, :2]
            self.dx_prev_hat = np.zeros(current_keypoints[:, :2].shape)
            return current_keypoints
        else:
            result = current_keypoints
            num_keypoints = len(current_keypoints)
            for i in range(num_keypoints):
                result[i, :2] = self.smooth(current_keypoints[i, :2], self.threshold[i], i)
            return result

    def smooth(self, current_keypoint, threshold, index):
        distance = np.sqrt(
            np.square((current_keypoint[0] - self.x_prev_hat[index][0]) / self.image_width)
            + np.square((current_keypoint[1] - self.x_prev_hat[index][1]) / self.image_height)
        )
        if distance < threshold:
            result = self.x_prev_hat[index]
        else:
            result = self.smooth_func(current_keypoint, self.x_prev_hat[index], index)

        return result

    def one_euro_filter(self, x_cur, x_pre, index):
        te = 1
        self.alpha = self.smoothing_factor(te, self.fc_d)
        dx_cur = (x_cur - x_pre) / te
        dx_cur_hat = self.exponential_smoothing(dx_cur, self.dx_prev_hat[index])

        fc = self.fc_min + self.beta * np.abs(dx_cur_hat)
        self.alpha = self.smoothing_factor(te, fc)
        x_cur_hat = self.exponential_smoothing(x_cur, x_pre)
        self.dx_prev_hat[index] = dx_cur_hat
        self.x_prev_hat[index] = x_cur_hat
        return x_cur_hat

    def ema_filter(self, x_cur, x_pre, index):
        x_cur_hat = self.exponential_smoothing(x_cur, x_pre)
        self.x_prev_hat[index] = x_cur_hat
        return x_cur_hat

    def smoothing_factor(self, te, fc):
        r = 2 * math.pi * fc * te
        return r / (r + 1)

    def exponential_smoothing(self, x_cur, x_pre, index=0):
        return self.alpha * x_cur + (1 - self.alpha) * x_pre


det_model_dir = "annotator/ppdet_hrnet/models/picodet_v2_s_320_pedestrian/"
keypoint_model_dir = "annotator/ppdet_hrnet/models/dark_hrnet_w32_256x192/"
keypoint_batch_size = 8
use_dark = True
device = "gpu"
run_mode = "paddle"
trt_min_shape = 1
trt_max_shape = 1920
trt_opt_shape = 640
trt_calib_mode = False
cpu_threads = 4
enable_mkldnn = False
det_threshold = 0.4

if not os.path.exists(det_model_dir):
    detmodel_url = (
        "https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_enhance/picodet_s_320_lcnet_pedestrian.zip"
    )
    get_path_from_url_with_filelock(detmodel_url, root_dir="annotator/ppdet_hrnet/models/")
if not os.path.exists(keypoint_model_dir):
    kptmodel_url = "https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip"
    get_path_from_url_with_filelock(kptmodel_url, root_dir="annotator/ppdet_hrnet/models/")


class PPDetPose(object):
    def __init__(self) -> None:
        deploy_file = os.path.join(det_model_dir, "infer_cfg.yml")
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        arch = yml_conf["arch"]
        detector_func = "Detector"
        if arch == "PicoDet":
            detector_func = "DetectorPicoDet"

        self.detector = eval(detector_func)(
            det_model_dir,
            device=device,
            run_mode=run_mode,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            threshold=det_threshold,
        )

        self.topdown_keypoint_detector = KeyPointDetector(
            keypoint_model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=keypoint_batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            use_dark=use_dark,
        )
        keypoint_arch = self.topdown_keypoint_detector.pred_config.arch
        assert (
            KEYPOINT_SUPPORT_MODELS[keypoint_arch] == "keypoint_topdown"
        ), "Detection-Keypoint unite inference only supports topdown models."

    def ppdet_hrnet_infer(self, image):
        # predict from image
        return topdown_unite_predict_singleimage(
            self.detector,
            self.topdown_keypoint_detector,
            image,
            keypoint_batch_size,
            det_threshold,
        )


def main():
    deploy_file = os.path.join(FLAGS.det_model_dir, "infer_cfg.yml")
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf["arch"]
    detector_func = "Detector"
    if arch == "PicoDet":
        detector_func = "DetectorPicoDet"

    detector = eval(detector_func)(
        FLAGS.det_model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        threshold=FLAGS.det_threshold,
    )

    topdown_keypoint_detector = KeyPointDetector(
        FLAGS.keypoint_model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.keypoint_batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        use_dark=FLAGS.use_dark,
    )
    keypoint_arch = topdown_keypoint_detector.pred_config.arch
    assert (
        KEYPOINT_SUPPORT_MODELS[keypoint_arch] == "keypoint_topdown"
    ), "Detection-Keypoint unite inference only supports topdown models."

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        topdown_unite_predict_video(
            detector,
            topdown_keypoint_detector,
            FLAGS.camera_id,
            FLAGS.keypoint_batch_size,
            FLAGS.save_res,
        )
    else:
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        topdown_unite_predict(
            detector,
            topdown_keypoint_detector,
            img_list,
            FLAGS.keypoint_batch_size,
            FLAGS.save_res,
        )
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
            topdown_keypoint_detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            det_model_dir = FLAGS.det_model_dir
            det_model_info = {
                "model_name": det_model_dir.strip("/").split("/")[-1],
                "precision": mode.split("_")[-1],
            }
            bench_log(detector, img_list, det_model_info, name="Det")
            keypoint_model_dir = FLAGS.keypoint_model_dir
            keypoint_model_info = {
                "model_name": keypoint_model_dir.strip("/").split("/")[-1],
                "precision": mode.split("_")[-1],
            }
            bench_log(
                topdown_keypoint_detector,
                img_list,
                keypoint_model_info,
                FLAGS.keypoint_batch_size,
                "KeyPoint",
            )


if __name__ == "__main__":
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ["CPU", "GPU", "XPU"], "device should be CPU, GPU or XPU"

    main()
